package com.yeskery.nut.extend.mybatis;

import com.yeskery.nut.transaction.Transaction;
import com.yeskery.nut.transaction.TransactionManager;
import org.apache.ibatis.session.SqlSession;
import org.apache.ibatis.session.SqlSessionFactory;

import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.sql.Connection;
import java.util.Optional;

/**
 * MyBatis Mapper 代理对象执行器
 * @author sprout
 * @version 1.0
 * 2022-07-07 21:38
 */
public class MapperProxyInvocationHandler implements InvocationHandler {

    /** Transaction manager; may be {@code null}, in which case every call runs without a transaction. */
    private final TransactionManager transactionManager;

    /** Factory used to open MyBatis SQL sessions. */
    private final SqlSessionFactory sqlSessionFactory;

    /** Mapper interface type whose implementation is fetched from each session. */
    private final Class<?> mapperClass;

    /**
     * Creates the invocation handler for a MyBatis mapper proxy.
     * @param transactionManager transaction manager, or {@code null} to always run without a transaction
     * @param sqlSessionFactory SQL session factory
     * @param mapperClass mapper interface type
     */
    public MapperProxyInvocationHandler(TransactionManager transactionManager, SqlSessionFactory sqlSessionFactory, Class<?> mapperClass) {
        this.transactionManager = transactionManager;
        this.sqlSessionFactory = sqlSessionFactory;
        this.mapperClass = mapperClass;
    }

    /**
     * Dispatches a call made on the mapper proxy.
     * <p>{@code equals}, {@code hashCode} and {@code toString} are answered locally. All other
     * methods are executed on a real MyBatis mapper. If the transaction manager reports a current
     * transaction, the call runs on that transaction's connection; otherwise it runs without one.
     * @param proxy the proxy instance the method was invoked on
     * @param method the invoked method
     * @param args the call arguments, possibly {@code null}
     * @return the mapper method's result
     * @throws Throwable the exception thrown by the mapper method itself (not a reflection wrapper)
     */
    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        // Object methods are answered locally. Forwarding them would open a session (and possibly a
        // connection) per call, and would make equals/hashCode inconsistent because a different
        // underlying mapper instance is obtained each time.
        if (method.getDeclaringClass() == Object.class) {
            return invokeObjectMethod(proxy, method, args);
        }
        Connection connection = null;
        if (transactionManager != null) {
            Optional<Transaction> optional = transactionManager.getCurrentTransactionOptional();
            if (optional.isPresent()) {
                connection = optional.get().getConnection();
            }
        }
        return doTransactionInvoke(connection, method, args);
    }

    /**
     * Executes a mapper method, inside the current transaction when a connection is given.
     * <p>With a non-null connection, the {@link SqlSession} bound to that connection is looked up
     * among the transaction resources and created and registered on first use, so all calls in
     * the same transaction share one session. With a {@code null} connection, an auto-commit
     * session is opened for this single call and closed afterwards.
     * @param connection the current transaction's connection, or {@code null} if there is no transaction
     * @param method the mapper method
     * @param args the call arguments
     * @return the method result
     * @throws Throwable the exception thrown by the mapper method
     */
    private Object doTransactionInvoke(Connection connection, Method method, Object[] args) throws Throwable {
        if (connection != null) {
            // A non-null connection is only produced when transactionManager is non-null (see invoke).
            SqlSession sqlSession = transactionManager.getTransactionResource(connection, SqlSession.class);
            if (sqlSession == null) {
                sqlSession = sqlSessionFactory.openSession(connection);
                transactionManager.putTransactionResource(connection, sqlSession);
            }
            return invokeUnwrapped(sqlSession.getMapper(mapperClass), method, args);
        }
        try (SqlSession sqlSession = sqlSessionFactory.openSession(true)) {
            Object result = invokeUnwrapped(sqlSession.getMapper(mapperClass), method, args);
            // NOTE(review): explicitly closes the session's connection before the session itself is
            // closed. This is redundant if the configured transaction factory closes the connection
            // on SqlSession.close(); it matters only if it does not. Confirm which factory is used.
            sqlSession.getConnection().close();
            return result;
        }
    }

    /**
     * Answers a {@code java.lang.Object} method called on the proxy, using identity semantics.
     * @param proxy the proxy instance
     * @param method the Object method being invoked
     * @param args the call arguments
     * @return the method result
     * @throws Throwable if a fallback reflective call fails
     */
    private Object invokeObjectMethod(Object proxy, Method method, Object[] args) throws Throwable {
        switch (method.getName()) {
            case "equals":
                return args != null && args.length == 1 && proxy == args[0];
            case "hashCode":
                return System.identityHashCode(proxy);
            case "toString":
                return mapperClass.getName() + "@" + Integer.toHexString(System.identityHashCode(proxy));
            default:
                // Other Object methods are not dispatched through a JDK proxy; delegate to this handler.
                return invokeUnwrapped(this, method, args);
        }
    }

    /**
     * Invokes a method reflectively and rethrows the target's own exception.
     * <p>{@link Method#invoke} wraps anything the target throws in an {@link InvocationTargetException}.
     * If that wrapper escaped a proxy whose interface does not declare it, the proxy would rethrow it
     * as {@code UndeclaredThrowableException} and hide the real error from callers.
     * @param target the object to invoke the method on
     * @param method the method
     * @param args the call arguments
     * @return the method result
     * @throws Throwable the exception thrown by the target method
     */
    private static Object invokeUnwrapped(Object target, Method method, Object[] args) throws Throwable {
        try {
            return method.invoke(target, args);
        } catch (InvocationTargetException e) {
            Throwable cause = e.getTargetException();
            throw cause != null ? cause : e;
        }
    }
}
